3113f8
@@ -36,7 +36,6 @@
 import org.apache.calcite.tools.RelBuilder;
 import org.apache.calcite.tools.RelBuilderFactory;
 import org.apache.calcite.util.ImmutableBitSet;
-import org.apache.calcite.util.ImmutableIntList;
 import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -72,6 +71,26 @@
private HiveSemiJoinRule(RelOptRuleOperand operand, RelBuilderFactory relBuilder
     super(operand, relBuilder, null);
   }
 
+  private RelNode buildProject(final Aggregate aggregate, RexBuilder rexBuilder, RelBuilder relBuilder) {
+    // Precondition: the aggregate carries no indicator and no aggregate calls (grouping keys only).
+    assert !aggregate.indicator && aggregate.getAggCallList().isEmpty();
+    final RelNode aggInput = aggregate.getInput();
+    final List<RexNode> groupKeyRefs = new ArrayList<>();
+    for (int keyIndex : aggregate.getGroupSet().asList()) {
+      groupKeyRefs.add(rexBuilder.makeInputRef(aggInput, keyIndex));
+    }
+    return relBuilder.push(aggInput).project(groupKeyRefs).build();
+  }
+
+  private boolean needProject(final RelNode input, final RelNode aggregate) {
+    // HIVE-15458: SemiJoin with a Join as its right input is not expected downstream; add a Project.
+    if (input instanceof HepRelVertex && ((HepRelVertex) input).getCurrentRel() instanceof Join) {
+      return true;
+    }
+    // Otherwise a Project is needed only when the aggregate and its input differ in field count.
+    return input.getRowType().getFieldCount() != aggregate.getRowType().getFieldCount();
+  }
+
   protected void perform(RelOptRuleCall call, ImmutableBitSet topRefs,
                          RelNode topOperator, Join join, RelNode left, Aggregate aggregate) {
     LOG.debug("Matched HiveSemiJoinRule");
@@ -107,29 +126,14 @@
protected void perform(RelOptRuleCall call, ImmutableBitSet topRefs,
     for (int key : joinInfo.rightKeys) {
       newRightKeyBuilder.add(aggregateKeys.get(key));
     }
-    final ImmutableIntList newRightKeys =
-        ImmutableIntList.copyOf(newRightKeyBuilder);
-    final RelNode newRight = aggregate.getInput();
+    RelNode input = aggregate.getInput();
+    final RelNode newRight = needProject(input, aggregate) ?
+        buildProject(aggregate, rexBuilder, call.builder()) : input;
     final RexNode newCondition =
         RelOptUtil.createEquiJoinCondition(left, joinInfo.leftKeys, newRight,
-                                           newRightKeys, rexBuilder);
-
-    RelNode semi = null;
-    //HIVE-15458: we need to add a Project on top of Join since SemiJoin with Join as it's right input
-    // is not expected further down the pipeline. see jira for more details
-    if(aggregate.getInput() instanceof HepRelVertex
-        && ((HepRelVertex)aggregate.getInput()).getCurrentRel() instanceof  Join) {
-      Join rightJoin = (Join)(((HepRelVertex)aggregate.getInput()).getCurrentRel());
-      List<RexNode> projects = new ArrayList<>();
-      for(int i=0; i<rightJoin.getRowType().getFieldCount(); i++){
-        projects.add(rexBuilder.makeInputRef(rightJoin, i));
-      }
-      RelNode topProject =  call.builder().push(rightJoin).project(projects, rightJoin.getRowType().getFieldNames(),
-                                                                   true).build();
-      semi = call.builder().push(left).push(topProject).semiJoin(newCondition).build();
-    } else {
-      semi = call.builder().push(left).push(aggregate.getInput()).semiJoin(newCondition).build();
-    }
+                                           joinInfo.rightKeys, rexBuilder);
+
+    RelNode semi = call.builder().push(left).push(newRight).semiJoin(newCondition).build();
     call.transformTo(topOperator.copy(topOperator.getTraitSet(), ImmutableList.of(semi)));
   }
 
